#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# @Time    : 2019-05-14 17:42
# @Author  : yuxuecheng
# @Contact : yuxuecheng@xinluomed.com
# @Site    : 
# @File    : tf_function.py
# @Software: PyCharm
# @Description tf.function example wechat url：https://mp.weixin.qq.com/s/fNh0bgBMdLE99K1CLfX3ig

import tensorflow as tf

# A tf.function-decorated function acts like a single operation.
@tf.function
def add(a, b):
    """Return ``a + b`` computed inside a ``tf.function``.

    Decorating a Python function with ``tf.function`` makes TensorFlow trace
    it into a graph, so the whole function can be used like a single op.

    Args:
        a: Left operand of ``+``.
        b: Right operand of ``+``.

    Returns:
        The result of ``a + b``.
    """
    total = a + b
    return total

# A tf.function can call another tf.function.
@tf.function
def dense_layer(x, w, b):
    """Compute ``matmul(x, w) + b`` (a dense layer with no activation).

    Shows that one ``tf.function`` may call another: the bias addition is
    delegated to the ``add`` function defined in this file.

    Args:
        x: Left matmul operand (``tf.matmul`` requires rank >= 2).
        w: Right matmul operand; presumably a (in_features, out_features)
            weight matrix — TODO confirm against callers.
        b: Bias added to the matmul result (broadcast by ``+``).

    Returns:
        ``tf.matmul(x, w) + b``.
    """
    projected = tf.matmul(x, w)
    return add(projected, b)


# Calling a tf.function with two 2x2 tensors of ones.
# Expected output:
# tf.Tensor(
# [[2. 2.]
#  [2. 2.]], shape=(2, 2), dtype=float32)
print(add(tf.ones([2, 2]), tf.ones([2, 2])))

# tf.function-decorated functions are differentiable: record the call on a
# GradientTape (trainable Variables are watched automatically) and ask for
# d(result)/d(v). Since result = v + 1.0, the gradient is 1.0.
v = tf.Variable(1.0)
with tf.GradientTape() as tape:
    result = add(v, 1.0)

# Expected output: tf.Tensor(1.0, shape=(), dtype=float32)
print(tape.gradient(result, v))

# ones(3x2) @ ones(2x2) is 2.0 everywhere; adding the broadcast bias
# ones([2]) gives 3.0 everywhere.
# Expected output:
# tf.Tensor(
# [[3. 3.]
#  [3. 3.]
#  [3. 3.]], shape=(3, 2), dtype=float32)
print(dense_layer(tf.ones([3, 2]), tf.ones([2, 2]), tf.ones([2])))

